<a name="readme-top"></a>



<!-- PROJECT LOGO -->
<br />
<div align="center">


  <h1 align="center">Generator Born from Classifier</h1>

  <p align="center">
    Learn a generator directly from a pre-trained classifier, <em><strong>without</strong></em> the assistance of any training data.
  </p>
</div>



<!-- TABLE OF CONTENTS -->
<details>
  <summary style="font-size: 20px;">Table of Contents</summary>
  <ol>
    <li><a href="#environment-setup">Environment Setup</a></li>
    <li><a href="#data-preparation">Data Preparation</a></li>
    <li><a href="#pre-trained-classifier">Pre-trained Classifier</a></li>
    <li><a href="#traing-configuration">Training Configuration</a></li>
    <li><a href="#run-the-code">Run the Code</a></li>
  </ol>
</details>

## Environment Setup
Our code is implemented with Python 3.8.7, PyTorch 1.11.0, and CUDA 11.7. All dependencies are listed in the `environment.yaml` file. Run the following commands to create and activate a virtual environment named *ctog*:

```bash
conda env create -f environment.yaml
conda activate ctog
```
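After activating *ctog*, a short script like the one below can confirm which Python, PyTorch, and CUDA versions the interpreter actually sees. It is a hypothetical helper (`describe_environment` is not part of the repository) that only reports what is installed; it does not compare the result against the versions listed above.

```python
import importlib
import platform


def describe_environment():
    """Collect the Python, PyTorch, and CUDA versions visible to this interpreter."""
    info = {"python": platform.python_version(), "torch": None, "cuda": None, "cuda_available": False}
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return info  # PyTorch is not installed in the active environment
    info["torch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # None for CPU-only PyTorch builds
    info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```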
<p align="right">(<a href="#readme-top">back to top</a>)</p>

## Data Preparation
The default data format is **torch.utils.data.Dataset**, and the code can load a saved dataset directly with **torch.load**. To add a new dataset, modify the `src/dataset/dataset.py` file. To simplify configuring the network architectures, each dataset should also provide the following three attributes:

- **input_shape \<list\>**: The shape of a single input sample, *e.g.*, `[channels, height, width]`.
- **num_classes \<int\>**: The number of classes in the dataset.
- **num_channels \<int\>**: The number of channels in the input data, *e.g.*, 3 for RGB images.
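A dataset exposing these three attributes might look like the sketch below. The class name `SyntheticImageDataset` and its constructor arguments are hypothetical; the only assumption taken from this README is that the object is a `torch.utils.data.Dataset` with `input_shape`, `num_classes`, and `num_channels` set.

```python
import torch
from torch.utils.data import Dataset


class SyntheticImageDataset(Dataset):
    """Hypothetical image dataset exposing the attributes used to configure the networks."""

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, num_classes: int):
        if images.dim() != 4:
            raise ValueError("expected images shaped [N, C, H, W]")
        if images.size(0) != labels.size(0):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels
        # Attributes read when building the classifier and generator:
        self.input_shape = list(images.shape[1:])  # [channels, height, width]
        self.num_classes = num_classes
        self.num_channels = images.shape[1]

    def __len__(self):
        return self.images.size(0)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]


if __name__ == "__main__":
    images = torch.rand(16, 1, 28, 28)  # random MNIST-sized images
    labels = torch.randint(0, 10, (16,))
    dataset = SyntheticImageDataset(images, labels, num_classes=10)
    print(dataset.input_shape, dataset.num_classes, dataset.num_channels)  # [1, 28, 28] 10 1
```

Such an object can be written with `torch.save(dataset, "path/to/data.pt")` and read back with `torch.load`, provided the class is importable at load time (one reason new dataset classes belong in `src/dataset/dataset.py`). PyTorch 2.6 and later default `torch.load` to `weights_only=True`, so loading a full `Dataset` object there needs `weights_only=False`.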

<p align="right">(<a href="#readme-top">back to top</a>)</p>


## Pre-trained Classifier
Our method supports any neural network that satisfies the definition of a quasi-homogeneous model, *e.g.*, linear networks with or without bias terms, normalization layers, and residual connections. The implementation supports networks that subclass **torch.nn.Module**.
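A ReLU MLP with bias terms is a simple case of the scaling behaviour behind quasi-homogeneity. The check below is an illustrative sketch, not code from this repository (`make_relu_mlp` and `rescale_parameters` are hypothetical helpers): for an MLP with L linear layers, multiplying every weight matrix by α > 0 and the bias of the l-th linear layer by α^l multiplies the logits by α^L, so the predicted class does not change.

```python
import copy

import torch
import torch.nn as nn


def make_relu_mlp(dims):
    """Build Linear -> ReLU -> ... -> Linear (all with biases) in float64."""
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:  # no activation after the output layer
            layers.append(nn.ReLU())
    return nn.Sequential(*layers).double()


def rescale_parameters(model, alpha):
    """Return a copy with every weight scaled by alpha and the l-th bias by alpha**l."""
    if alpha <= 0:
        raise ValueError("alpha must be positive (ReLU is only positively homogeneous)")
    scaled = copy.deepcopy(model)
    linears = [m for m in scaled.modules() if isinstance(m, nn.Linear)]
    with torch.no_grad():
        for depth, layer in enumerate(linears, start=1):
            layer.weight.mul_(alpha)
            layer.bias.mul_(alpha ** depth)
    return scaled


if __name__ == "__main__":
    torch.manual_seed(0)
    model = make_relu_mlp([784, 512, 512, 10])  # three linear layers, so L = 3
    x = torch.randn(8, 784, dtype=torch.float64)
    alpha = 2.0
    scaled = rescale_parameters(model, alpha)
    print(torch.allclose(scaled(x), alpha ** 3 * model(x)))  # True
    print(torch.equal(scaled(x).argmax(1), model(x).argmax(1)))  # True
```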

To add a new network architecture, modify the `src/models/classifier.py` file.
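A new classifier could look roughly like the following. This is a hypothetical sketch, not the repository's `LinearNet`: the class name `LinearNetSketch` is made up, and it only assumes that the dataset attributes `input_shape` and `num_classes` and the config fields `hidden_dims` and `activation_name` (as in the example `.yaml` file under Training Configuration) are passed to the constructor. How the repository selects a model by `Name` and loads the checkpoint at `load_path` is not shown here.

```python
import math

import torch
import torch.nn as nn


class LinearNetSketch(nn.Module):
    """Hypothetical MLP classifier built from dataset attributes and config fields."""

    def __init__(self, input_shape, num_classes, hidden_dims, activation_name="ReLU"):
        super().__init__()
        activation_cls = getattr(nn, activation_name)  # e.g. "ReLU" -> torch.nn.ReLU
        dims = [math.prod(input_shape), *hidden_dims]
        layers = [nn.Flatten()]  # [N, C, H, W] -> [N, C*H*W]
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(in_dim, out_dim), activation_cls()]
        layers.append(nn.Linear(dims[-1], num_classes))  # logits, no activation
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


if __name__ == "__main__":
    model = LinearNetSketch(input_shape=[1, 28, 28], num_classes=10, hidden_dims=[512, 512], activation_name="ReLU")
    logits = model(torch.rand(4, 1, 28, 28))
    print(logits.shape)  # torch.Size([4, 10])
```

Built only from linear layers with biases and ReLU activations, this model stays within the family of networks listed above.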

<p align="right">(<a href="#readme-top">back to top</a>)</p>

<a name="traing-configuration"></a>
## Training Configuration

Our training script accepts hyperparameters and other configuration options in two ways. Relatively fixed settings are stored in a `.yaml` file, while hyperparameters that vary between runs or are being tuned can be passed as command-line arguments when launching the script. All predefined parameters are listed in the `src/configs/base_config.py` file; edit this file to add new parameters or change the definitions of existing ones. A `.yaml` file must begin by importing the base configuration through the `_BASE_` key. Here is an example of a valid `.yaml` file:

```yaml
_BASE_: "../_base_config.yaml"
action: "train_generator"
Data:
  Name: "MNIST"
  path: "path/to/data.pt"
Classifier:
  Name: "LinearNet"
  hidden_dims: [512,512]
  activation_name: "ReLU"
  load_path: "path/to/pre-trained/classifier.pt"
Generator:
  Name: "LinearG"
  batch_size: 500
  lr: 1.0e-3
  weight_decay: 1.0e-4
  epoches: 5000
  evaluation_interval: 10
  num_latent_feature: 128
  num_class_embedding: 128
  hidden_dims: [512,512]
```
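Inheritance through `_BASE_` usually means: load the parent file, resolve its own `_BASE_` recursively, then merge the child's keys on top, section by section. The sketch below illustrates that idea on plain dictionaries. It is an assumption about the semantics, not the repository's loader; `merge_configs`, `resolve_base`, and the caller-supplied `load` function are hypothetical, relative-path resolution is omitted, and whether the real loader also validates keys or types against `base_config.py` is not shown here.

```python
import copy


def merge_configs(base, override):
    """Recursively merge `override` into a copy of `base`; nested sections merge key by key."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def resolve_base(config, load):
    """Follow `_BASE_`: load and resolve the parent with `load(path)`, then apply this config on top."""
    config = dict(config)  # shallow copy so the caller's dict keeps its `_BASE_` key
    base_path = config.pop("_BASE_", None)
    if base_path is None:
        return config
    return merge_configs(resolve_base(load(base_path), load), config)


if __name__ == "__main__":
    # In-memory stand-ins for parsed YAML files; the base values are hypothetical defaults.
    files = {
        "../_base_config.yaml": {"Generator": {"batch_size": 256, "lr": 1e-3, "epoches": 100}},
        "mnist.yaml": {"_BASE_": "../_base_config.yaml", "Generator": {"batch_size": 500, "epoches": 5000}},
    }
    print(resolve_base(files["mnist.yaml"], files.__getitem__))
    # {'Generator': {'batch_size': 500, 'lr': 0.001, 'epoches': 5000}}
```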

<p align="right">(<a href="#readme-top">back to top</a>)</p>

## Run the Code

Once the configuration is ready, run `main.py` to train the generator. Here is a sample launch command:

```sh
# Run from the repository root, where main.py is located.

CUDA_VISIBLE_DEVICES=0 python main.py \
  --config-file "path/to/the/configuration/file.yaml" \
  Generator.lagrange_coe 1.0 \
  Generator.duality_coe 1.0 \
  Generator.regularization_coe 1.0
```
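The trailing `Generator.lagrange_coe 1.0`-style arguments are `KEY VALUE` pairs whose dotted keys address nested sections of the configuration. A minimal sketch of how such overrides can be applied to a nested dictionary is below; `apply_overrides` is hypothetical, and it assumes (as the README suggests by listing all parameters in `src/configs/base_config.py`) that only predefined keys may be overridden. Values are parsed with `ast.literal_eval`, so `1.0` becomes a float and `[512,512]` a list, while non-literals such as `ReLU` stay strings.

```python
import ast


def apply_overrides(config, args):
    """Apply `Section.key value` pairs, as in the launch command, to a nested dict in place."""
    if len(args) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for dotted_key, raw_value in zip(args[0::2], args[1::2]):
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            if name not in node:
                raise KeyError(f"unknown config section: {dotted_key}")
            node = node[name]
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        try:
            value = ast.literal_eval(raw_value)  # "1.0" -> 1.0, "[512,512]" -> [512, 512]
        except (ValueError, SyntaxError):
            value = raw_value  # keep non-literals such as ReLU as plain strings
        node[leaf] = value
    return config


if __name__ == "__main__":
    # Hypothetical defaults standing in for the values defined in src/configs/base_config.py.
    config = {"Generator": {"lagrange_coe": 0.0, "duality_coe": 0.0, "regularization_coe": 0.0}}
    argv = ["Generator.lagrange_coe", "1.0", "Generator.duality_coe", "1.0",
            "Generator.regularization_coe", "1.0"]
    print(apply_overrides(config, argv))
    # {'Generator': {'lagrange_coe': 1.0, 'duality_coe': 1.0, 'regularization_coe': 1.0}}
```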

<p align="right">(<a href="#readme-top">back to top</a>)</p>